{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['finetune-t5-base-standard-bahasa-cased/checkpoint-360000',\n",
       " 'finetune-t5-base-standard-bahasa-cased/checkpoint-370000',\n",
       " 'finetune-t5-base-standard-bahasa-cased/checkpoint-380000']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from glob import glob\n",
    "\n",
    "checkpoints = sorted(glob('finetune-t5-base-standard-bahasa-cased/checkpoint-*'))\n",
    "checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "\n",
    "tokenizer = T5Tokenizer.from_pretrained('mesolitica/t5-small-standard-bahasa-cased')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = T5ForConditionalGeneration.from_pretrained(checkpoints[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample sentence that the next cell feeds to the model behind the 'parafrasa: ' prefix.\n",
    "s = \"Kabinet bersetuju mewujudkan satu jawatan kuasa dalaman untuk menyiasat isu 'perbalahan' antara Jabatan Perkhidmatan Awam (JPA) dan Kesatuan Perkhidmatan Imigresen Semenanjung Malaysia (KPISM)\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[\"Kabinet bersetuju mewujudkan jawatan kuasa dalaman untuk menyiasat isu 'perbalahan' antara Jabatan Perkhidmatan Awam (JPA) dan Kesatuan Perkhidmatan Imigresen Semenanjung Malaysia (KPISM).\",\n",
       " \"Kabinet bersetuju untuk mewujudkan jawatan kuasa dalaman untuk menyiasat isu 'perbalahan' antara Jabatan Perkhidmatan Awam (JPA) dan Kesatuan Perkhidmatan Imigresen Semenanjung Malaysia (KPISM).\",\n",
       " \"Kabinet bersetuju untuk mewujudkan satu jawatan kuasa dalaman untuk menyiasat isu 'perbalahan' antara Jabatan Perkhidmatan Awam (JPA) dan Kesatuan Perkhidmatan Imigresen Semenanjung Malaysia (KPISM).\",\n",
       " \"Kabinet bersetuju untuk mewujudkan jawatan berkuasa dalaman untuk menyiasat isu 'perbalahan' antara Jabatan Perkhidmatan Awam (JPA) dan Kesatuan Perkhidmatan Imigresen Semenanjung Malaysia (KPISM).\",\n",
       " \"Kabinet bersetuju untuk mewujudkan jawatan dalaman kuasa yang bertujuan untuk menyiasat 'perbalahan' antara Jabatan Perkhidmatan Awam (JPA) dan Kesatuan Perkhidmatan Imigresen Semenanjung Malaysia (KPISM).\"]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_ids = tokenizer.encode(f'parafrasa: {s}', return_tensors = 'pt')\n",
    "outputs = model.generate(input_ids, do_sample=True,\n",
    "    max_length=256,\n",
    "    top_k=50,\n",
    "    top_p=0.95,\n",
    "    early_stopping=True,\n",
    "    num_return_sequences=5)\n",
    "tokenizer.batch_decode(outputs, skip_special_tokens = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.8/site-packages/transformers/utils/hub.py:651: UserWarning: The `organization` argument is deprecated and will be removed in v5 of Transformers. Set your organization directly in the `repo_id` passed instead (`repo_id={organization}/{model_id}`).\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/mesolitica/finetune-paraphrase-t5-base-standard-bahasa-cased/commit/67a4f28830f1fe2aaf69097200d9663528479192', commit_message='Upload T5ForConditionalGeneration', commit_description='', oid='67a4f28830f1fe2aaf69097200d9663528479192', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.push_to_hub('finetune-paraphrase-t5-base-standard-bahasa-cased', organization='mesolitica')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/mesolitica/finetune-paraphrase-t5-base-standard-bahasa-cased/commit/83df426b9bfd417d29f39055add9f714bcbe1dbd', commit_message='Upload tokenizer', commit_description='', oid='83df426b9bfd417d29f39055add9f714bcbe1dbd', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.push_to_hub('finetune-paraphrase-t5-base-standard-bahasa-cased', organization='mesolitica')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sacrebleu.metrics import BLEU, CHRF, TER\n",
    "\n",
    "bleu = BLEU()\n",
    "chrf = CHRF(word_order = 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "test = []\n",
    "with open('shuffled-test.json') as fopen:\n",
    "    for l in fopen:\n",
    "        test.append(json.loads(l))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 5544/5544 [51:02<00:00,  1.81it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "batch_size = 1\n",
    "\n",
    "results = []\n",
    "for i in tqdm(range(0, len(test), batch_size)):\n",
    "    input_ids = [{'input_ids': tokenizer.encode(f\"parafrasa: {s['translation']['src']}\", return_tensors = 'pt')[0]} for s in test[i:i + batch_size]]\n",
    "    padded = tokenizer.pad(input_ids, padding = 'longest')\n",
    "    outputs = model.generate(**padded, max_length = 256)\n",
    "    for o in outputs:\n",
    "        results.append(tokenizer.decode(o, skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "filtered_left, filtered_right = [], []\n",
    "for no, r in enumerate(results):\n",
    "    if len(r):\n",
    "        filtered_left.append(r)\n",
    "        filtered_right.append(test[no]['translation']['tgt'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'name': 'BLEU',\n",
       " 'score': 35.95965899952292,\n",
       " '_mean': -1.0,\n",
       " '_ci': -1.0,\n",
       " '_verbose': '61.7/41.3/32.0/25.8 (BP = 0.944 ratio = 0.946 hyp_len = 95593 ref_len = 101064)',\n",
       " 'bp': 0.9443747373110852,\n",
       " 'counts': [59014, 37157, 27016, 20383],\n",
       " 'totals': [95593, 90049, 84505, 78961],\n",
       " 'sys_len': 95593,\n",
       " 'ref_len': 101064,\n",
       " 'precisions': [61.73464584226878,\n",
       "  41.263090095392506,\n",
       "  31.969705934560086,\n",
       "  25.81400944770203],\n",
       " 'prec_str': '61.7/41.3/32.0/25.8',\n",
       " 'ratio': 0.9458659859099184}"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "refs = [filtered_right]\n",
    "sys = filtered_left\n",
    "r = bleu.corpus_score(sys, refs)\n",
    "r.__dict__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
